import torch
import numpy as np

from dada.model.torch_model import TorchModel


class WorstCaseModel(TorchModel):
    """Model whose objective is the chained q-power function

        f(x) = (1/q) * sum_{i=1}^{n-1} |x_i - x_{i+1}|^q + (1/q) * |x_n|^q

    It can be evaluated on the model's own parameters (``loss``, differentiable
    torch result) or on an arbitrary point (``compute_value``, NumPy result).

    NOTE(review): the name suggests this is the classic "worst-case" test
    function for first-order methods; confirm against the experiments using it.
    """

    def __init__(self,
                 n_features: int,
                 q: float,
                 init_point: torch.Tensor = None):
        """
        Args:
            n_features: Dimension of the optimisation variable; forwarded to
                ``TorchModel``.
            q: Exponent of the objective. Must be positive: ``q == 0`` divides
                by zero and ``q < 0`` makes the objective unbounded below.
            init_point: Optional starting point (may be ``None``); forwarded
                to ``TorchModel``.

        Raises:
            ValueError: If ``q`` is not a positive number (NaN is rejected too).
        """
        if not q > 0:
            raise ValueError(f"q must be positive, got {q!r}")

        # Set before calling the base constructor, as the original did.
        # NOTE(review): presumably TorchModel.__init__ may evaluate the
        # objective and thus needs self.q; confirm before reordering.
        self.q = q
        super().__init__(n_features, init_point)

    def _worst_case_objective(self, x, xp):
        """Shared formula behind ``loss`` and ``compute_value``.

        Args:
            x: 1-D array/tensor of coordinates x_1 ... x_n (indexed with
                ``[:-1]``, ``[1:]`` and ``[-1]``).
            xp: Array module supplying ``abs`` and ``sum`` (``torch`` or
                ``numpy``), so both entry points use the exact same math.
        """
        differences = xp.abs(x[:-1] - x[1:]) ** self.q
        summation_term = xp.sum(differences)
        final_term = xp.abs(x[-1]) ** self.q
        # Kept as two separately scaled terms rather than (a + b) / q so the
        # floating-point result matches the original formula exactly.
        return (1 / self.q) * summation_term + (1 / self.q) * final_term

    def loss(self):
        """Return the objective at the current parameters as a torch scalar.

        Uses torch ops only, so the result stays on the autograd graph.
        Assumes ``self.x`` is the 1-D parameter tensor managed by
        ``TorchModel``. TODO confirm.
        """
        return self._worst_case_objective(self.x, torch)

    def compute_value(self, point: "torch.Tensor | np.ndarray"):
        """Evaluate the objective at ``point`` with NumPy (no autograd).

        Args:
            point: 1-D NumPy array, torch tensor (any device, with or without
                ``requires_grad``), or other array-like.

        Returns:
            The objective value as a NumPy scalar.
        """
        if isinstance(point, torch.Tensor):
            # NumPy cannot implicitly convert tensors that require grad or
            # live on an accelerator; detach and move to CPU explicitly.
            point = point.detach().cpu().numpy()
        return self._worst_case_objective(np.asarray(point), np)
